# -*-Python-*-
# 128-way data parallelism and 2-way model parallelism on an 8x16 TPU.

# Mesh dimensions, written as comma-separated "name:size" pairs: a `batch`
# axis of size 128 and a `model` axis of size 2.
# NOTE(review): 128 * 2 = 256, presumably the total core count of the target
# TPU slice -- confirm before reusing this file on different hardware.
utils.run.mesh_shape = "batch:128,model:2"
# Layout rules, written as comma-separated "name:mesh_axis" pairs: `batch` is
# mapped to the `batch` mesh axis, while `d_ff`, `heads` and `vocab` are all
# mapped to the `model` mesh axis.
# NOTE(review): the left-hand names are presumably tensor dimension names that
# get split across the named mesh axis, with unlisted dimensions left
# unsplit -- verify against the consumer of utils.run.
utils.run.layout_rules = "batch:batch,d_ff:model,heads:model,vocab:model"
